{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dbb50720-0572-4b30-9e89-7ab7bdb2428b",
   "metadata": {},
   "source": [
    "<table align=\"left\"><td><a target=\"_blank\" href=\"https://beam.apache.org/documentation/io/built-in/webapis/\"><img src=\"https://beam.apache.org/images/logos/full-color/name-bottom/beam-logo-full-color-name-bottom-100.png\" width=\"32\" height=\"32\" />View the docs</a></td></table>"
   ]
  },
  {
   "cell_type": "raw",
   "id": "6dbfef30-d44c-482a-a122-31c07be80d77",
   "metadata": {},
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\")\n",
    "# Licensed to the Apache Software Foundation (ASF) under one\n",
    "# or more contributor license agreements. See the NOTICE file\n",
    "# distributed with this work for additional information\n",
    "# regarding copyright ownership. The ASF licenses this file\n",
    "# to you under the Apache License, Version 2.0 (the\n",
    "# \"License\"); you may not use this file except in compliance\n",
    "# with the License. You may obtain a copy of the License at\n",
    "#\n",
    "#   http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing,\n",
    "# software distributed under the License is distributed on an\n",
    "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
    "# KIND, either express or implied. See the License for the\n",
    "# specific language governing permissions and limitations\n",
    "# under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "391810af-3391-46da-8551-0ea98a75f8f3",
   "metadata": {},
   "source": [
    "`ImageRequest` is the custom request passed to the `HttpImageClient`, which uses it to make the HTTP call\n",
    "that downloads the image.\n",
    "\n",
    "`ImageResponse` is the custom response returned by the `HttpImageClient`. It contains the image data\n",
    "fetched from the remote server using the image URL.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93e19710-82a2-4fdd-b1a1-f6239c08c229",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "\n",
    "class ImageRequest:\n",
    "    image_url_to_mime_type = {\n",
    "        \"jpg\": \"image/jpeg\",\n",
    "        \"jpeg\": \"image/jpeg\",\n",
    "        \"png\": \"image/png\",\n",
    "    }\n",
    "\n",
    "    def __init__(self, image_url):\n",
    "        self.image_url = image_url\n",
    "        self.mime_type = self.image_url_to_mime_type.get(image_url.split(\".\")[-1])\n",
    "\n",
    "ImageResponse = namedtuple(\"ImageResponse\", [\"mime_type\", \"data\"])"
   ]
  },
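  {
   "cell_type": "markdown",
   "id": "5a1c9e02-7b3d-4f6a-8c21-3e9d0b7a4f10",
   "metadata": {},
   "source": [
    "The extension lookup in `ImageRequest` splits the whole URL on `.`, so a URL with a query string or an\n",
    "upper-case extension yields no MIME type. As a minimal sketch, assuming you want to tolerate such URLs,\n",
    "the lookup could use only the URL path. The `mime_type_for` helper and the `example.com` URLs are\n",
    "hypothetical and are not part of Beam or of this notebook's pipeline.\n",
    "\n",
    "```python\n",
    "import posixpath\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "# Same mapping as ImageRequest.image_url_to_mime_type.\n",
    "MIME_TYPES = {\"jpg\": \"image/jpeg\", \"jpeg\": \"image/jpeg\", \"png\": \"image/png\"}\n",
    "\n",
    "\n",
    "def mime_type_for(image_url):\n",
    "    # Hypothetical helper: inspect only the URL path so that query strings\n",
    "    # and fragments do not end up in the extension.\n",
    "    path = urlparse(image_url).path\n",
    "    extension = posixpath.splitext(path)[1].lstrip(\".\").lower()\n",
    "    return MIME_TYPES.get(extension)\n",
    "\n",
    "\n",
    "print(mime_type_for(\"https://example.com/images/cake.JPG?size=large\"))\n",
    "print(mime_type_for(\"https://example.com/notes.txt\"))\n",
    "```\n",
    "\n",
    "Unknown extensions still map to `None`, which matches the behaviour of the dictionary lookup in `ImageRequest`.\n"
   ]
  },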
  {
   "cell_type": "markdown",
   "id": "6c4f6fdf-5710-44d4-9c9c-b5e3f69982be",
   "metadata": {},
   "source": [
    "#### Define Caller\n",
    "\n",
    "We implement the `Caller`, the `HttpImageClient`, that receives an `ImageRequest` and returns an `ImageResponse`.\n",
    "\n",
    "_For demo purposes, the example passes a `(url, request)` tuple so that the raw URL is returned alongside the `ImageResponse`._\n",
    "\n",
    "`RequestResponseIO` retries a failed call only when the `Caller` raises specific errors. For those errors,\n",
    "the transform retries using a prescribed exponential backoff before it raises an exception.\n",
    "Your `Caller` must therefore raise one of these errors to signal that the call should be retried.\n",
    "\n",
    "`RequestResponseIO` attempts a retry with backoff when the `Caller` raises:\n",
    "* `UserCodeQuotaException`\n",
    "* `UserCodeTimeoutException`\n",
    "\n",
    "After a threshold number of retries, the error is re-raised.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bac514d6-8c1f-4221-bbad-8a6a122f9369",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "from apache_beam.io.requestresponse import (\n",
    "    Caller,\n",
    "    UserCodeExecutionException,\n",
    "    UserCodeQuotaException,\n",
    "    UserCodeTimeoutException,\n",
    ")\n",
    "\n",
    "\n",
    "class HttpImageClient(Caller):\n",
    "    STATUS_TOO_MANY_REQUESTS = 429\n",
    "    STATUS_TIMEOUT = 408\n",
    "\n",
    "    def __call__(self, kv):\n",
    "        url, request = kv\n",
    "        try:\n",
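    "            # Without an explicit timeout, requests can wait indefinitely;\n",
    "            # 30 seconds is an illustrative value.\n",
    "            response = requests.get(request.image_url, timeout=30)\n",
    "        except requests.exceptions.Timeout as e:\n",
    "            raise UserCodeTimeoutException() from e\n",
    "        except requests.exceptions.RequestException as e:\n",
    "            # Covers connection errors and other request failures.\n",
    "            raise UserCodeExecutionException() from e\n",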
    "\n",
    "        if response.status_code >= 500:\n",
    "            raise UserCodeExecutionException()\n",
    "\n",
    "        if response.status_code >= 400:\n",
    "            match response.status_code:\n",
    "                case self.STATUS_TOO_MANY_REQUESTS:\n",
    "                    raise UserCodeQuotaException()\n",
    "                case self.STATUS_TIMEOUT:\n",
    "                    raise UserCodeTimeoutException()\n",
    "                case _:\n",
    "                    raise UserCodeExecutionException()\n",
    "\n",
    "        return url, ImageResponse(request.mime_type, response.content)"
   ]
  },
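  {
   "cell_type": "markdown",
   "id": "c7e4a1b8-2f60-4d9b-a3c5-91d8e6f0b274",
   "metadata": {},
   "source": [
    "To picture the retry behaviour described for `RequestResponseIO`, the following is a minimal plain-Python\n",
    "sketch of retrying with exponential backoff and re-raising once a retry threshold is reached. It is an\n",
    "illustration only, not Beam's implementation: `QuotaError`, `call_with_backoff`, `FlakyCaller`, and the\n",
    "default values for `max_retries` and `initial_delay` are all hypothetical.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "class QuotaError(Exception):\n",
    "    \"\"\"Hypothetical stand-in for a retryable error.\"\"\"\n",
    "\n",
    "\n",
    "def call_with_backoff(caller, request, max_retries=3, initial_delay=0.01, sleep=time.sleep):\n",
    "    delay = initial_delay\n",
    "    for attempt in range(max_retries + 1):\n",
    "        try:\n",
    "            return caller(request)\n",
    "        except QuotaError:\n",
    "            if attempt == max_retries:\n",
    "                raise  # Retry budget exhausted: surface the error.\n",
    "            sleep(delay)\n",
    "            delay *= 2  # Double the wait before the next attempt.\n",
    "\n",
    "\n",
    "class FlakyCaller:\n",
    "    \"\"\"Fails a fixed number of times, then succeeds.\"\"\"\n",
    "\n",
    "    def __init__(self, failures):\n",
    "        self.failures = failures\n",
    "        self.calls = 0\n",
    "\n",
    "    def __call__(self, request):\n",
    "        self.calls += 1\n",
    "        if self.calls <= self.failures:\n",
    "            raise QuotaError()\n",
    "        return f\"ok:{request}\"\n",
    "\n",
    "\n",
    "flaky = FlakyCaller(failures=2)\n",
    "result = call_with_backoff(flaky, \"cake.jpg\")\n",
    "print(result, flaky.calls)\n",
    "```\n",
    "\n",
    "Errors other than `QuotaError` propagate immediately in this sketch, which mirrors the idea that only\n",
    "specific error types signal a retry.\n"
   ]
  },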
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eed9274-335a-4a0b-8676-e61fb6fb1875",
   "metadata": {},
   "outputs": [],
   "source": [
    "images = [\n",
    "    \"https://storage.googleapis.com/generativeai-downloads/images/cake.jpg\",\n",
    "    \"https://storage.googleapis.com/generativeai-downloads/images/chocolate.png\",\n",
    "    \"https://storage.googleapis.com/generativeai-downloads/images/croissant.jpg\",\n",
    "    \"https://storage.googleapis.com/generativeai-downloads/images/dog_form.jpg\",\n",
    "    \"https://storage.googleapis.com/generativeai-downloads/images/factory.png\",\n",
    "    \"https://storage.googleapis.com/generativeai-downloads/images/scones.jpg\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54d5965-9b5e-413e-94cc-a50712bfdd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "import apache_beam as beam\n",
    "from apache_beam.io.requestresponse import (\n",
    "    RequestResponseIO,\n",
    ")\n",
    "from apache_beam.options.pipeline_options import PipelineOptions\n",
    "\n",
    "\n",
    "def build_image_request(image_url):\n",
    "    return image_url, ImageRequest(image_url)\n",
    "\n",
    "with beam.Pipeline(options=PipelineOptions(pickle_library=\"cloudpickle\")) as pipeline:\n",
    "    _ = (\n",
    "        pipeline\n",
    "        | \"Create data\" >> beam.Create(images)\n",
    "        | \"Map to ImageRequest\" >> beam.Map(build_image_request)\n",
    "        | \"Download image\" >> RequestResponseIO(HttpImageClient())\n",
    "        | \"Print results\"\n",
    "        >> beam.MapTuple(\n",
    "            lambda url, response: print(\n",
    "                f\"{url}, mimeType={response.mime_type}, size={len(response.data)}\"\n",
    "            )\n",
    "        )\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd76a354-09ae-41e1-8dd4-40815946e8f4",
   "metadata": {},
   "source": [
    "The previous example invoked HTTP requests directly. However, some API services provide\n",
    "client code that should be used within the `Caller` implementation. Using client code within Beam presents\n",
    "unique challenges, namely serialization. Additionally, some client code requires explicit\n",
    "setup and teardown.\n",
    "\n",
    "`RequestResponseIO` supports such setup and teardown when you override the context manager dunder methods\n",
    "`__enter__` and `__exit__` on the `Caller`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b96fa3c-a8af-4d31-9249-c5b301c32632",
   "metadata": {},
   "outputs": [],
   "source": [
    "from google import genai\n",
    "from google.genai import types\n",
    "from google.genai.errors import APIError\n",
    "\n",
    "API_KEY = \"<your api key>\"\n",
    "\n",
    "class GeminiAIClient(Caller):\n",
    "    MODEL_GEMINI_FLASH_LITE = \"gemini-2.0-flash-lite\"\n",
    "\n",
    "    def __init__(self, api_key):\n",
    "        self.api_key = api_key\n",
    "\n",
    "    def __enter__(self):\n",
    "        self.client = genai.Client(api_key=self.api_key)\n",
    "        return self\n",
    "\n",
    "    def __call__(self, kv):\n",
    "        url, request = kv\n",
    "        try:\n",
    "            response = self.client.models.generate_content(\n",
    "                model=self.MODEL_GEMINI_FLASH_LITE,\n",
    "                contents=[\n",
    "                    types.Part.from_bytes(\n",
    "                        data=request.data,\n",
    "                        mime_type=request.mime_type,\n",
    "                    ),\n",
    "                    \"Caption this image.\",\n",
    "                ],\n",
    "            )\n",
    "        except APIError as e:\n",
    "            raise UserCodeExecutionException() from e\n",
    "\n",
    "        return url, response"
   ]
  },
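  {
   "cell_type": "markdown",
   "id": "e2d90f46-8a17-4c3e-b5f2-6d0a7c91e834",
   "metadata": {},
   "source": [
    "The setup and teardown pattern can be tried without Beam or a real API client. The following is a minimal\n",
    "sketch, assuming a client that must be opened before use and closed afterwards. `FakeClient` and\n",
    "`CaptionCaller` are hypothetical names used only for illustration. The constructor stores only plain\n",
    "configuration, and the client itself is created in `__enter__` and released in `__exit__`.\n",
    "\n",
    "```python\n",
    "class FakeClient:\n",
    "    \"\"\"Hypothetical stand-in for an API client that must be opened and closed.\"\"\"\n",
    "\n",
    "    def __init__(self, api_key):\n",
    "        self.api_key = api_key\n",
    "        self.closed = False\n",
    "\n",
    "    def caption(self, data):\n",
    "        return f\"caption for {len(data)} bytes\"\n",
    "\n",
    "    def close(self):\n",
    "        self.closed = True\n",
    "\n",
    "\n",
    "class CaptionCaller:\n",
    "    def __init__(self, api_key):\n",
    "        # Keep only plain configuration here; no live client yet.\n",
    "        self.api_key = api_key\n",
    "        self.client = None\n",
    "\n",
    "    def __enter__(self):\n",
    "        # Setup: create the client when the caller is about to be used.\n",
    "        self.client = FakeClient(self.api_key)\n",
    "        return self\n",
    "\n",
    "    def __exit__(self, exc_type, exc_value, traceback):\n",
    "        # Teardown: close the client and drop the reference.\n",
    "        self.client.close()\n",
    "        self.client = None\n",
    "\n",
    "    def __call__(self, kv):\n",
    "        key, data = kv\n",
    "        return key, self.client.caption(data)\n",
    "\n",
    "\n",
    "caption_caller = CaptionCaller(\"fake-key\")\n",
    "with caption_caller as caller:\n",
    "    output = caller((\"cake.jpg\", b\"abc\"))\n",
    "\n",
    "print(output)\n",
    "```\n",
    "\n",
    "Because the client exists only between `__enter__` and `__exit__`, the caller object holds nothing but the\n",
    "API key at the point where it would need to be serialized.\n"
   ]
  },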
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e2fb1d8-691e-4366-904c-7d25435e68d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "with beam.Pipeline(options=PipelineOptions(pickle_library=\"cloudpickle\")) as pipeline:\n",
    "    _ = (\n",
    "        pipeline\n",
    "        | \"Create data\" >> beam.Create(images)\n",
    "        | \"Map to ImageRequest\" >> beam.Map(build_image_request)\n",
    "        | \"Download image\" >> RequestResponseIO(HttpImageClient())\n",
    "        | \"Gemini AI\" >> RequestResponseIO(GeminiAIClient(API_KEY))\n",
    "        | \"Print results\"\n",
    "        >> beam.MapTuple(lambda url, response: print(url, response.text))\n",
    "    )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
